{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f969c229-5a8a-44b6-91a3-bba55968b202",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch,imageio,sys,time,ffmpeg,cv2,cmapy,os\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "import matplotlib.pyplot as plt\n",
    "from omegaconf import OmegaConf\n",
    "\n",
    "sys.path.append('..')\n",
    "# from models.sparseCoding import sparseCoding \n",
    "from models.FactorFields import FactorFields \n",
    "\n",
    "from utils import SimpleSampler\n",
    "from dataLoader import dataset_dict\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "# All model and training tensors are placed on CUDA; GPU 0 is made the current device.\n",
    "device = 'cuda'\n",
    "torch.cuda.set_device(0)\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b9acd258-ab23-489a-93a0-7bc799bbab24",
   "metadata": {},
   "outputs": [],
   "source": [
    "def PSNR(a,b):\n",
    "    if type(a).__module__ == np.__name__:\n",
    "        mse = np.mean((a-b)**2)\n",
    "    else:\n",
    "        mse = torch.mean((a-b)**2).item()\n",
    "    psnr = -10.0 * np.log(mse) / np.log(10.0)\n",
    "    return psnr\n",
    "\n",
    "@torch.no_grad()\n",
    "def eval_img(reso, chunk=10240,target_region=[0.0,0.0,1.0,1.0]):\n",
    "    \"\"\"Evaluate the global `model` on a regular grid of pixel coordinates.\n",
    "\n",
    "    Args:\n",
    "        reso: (rows, cols) of the output grid.\n",
    "        chunk: maximum number of coordinates sent through the model at once.\n",
    "        target_region: [y0, x0, y1, x1] as fractions of the image extent; they are\n",
    "            scaled by the global (H-1) and (W-1) to obtain pixel coordinates.\n",
    "\n",
    "    Returns:\n",
    "        A CPU tensor of shape (reso[0], reso[1], C) holding the model output.\n",
    "    \"\"\"\n",
    "    rows = torch.linspace(target_region[0], target_region[2], reso[0]) * (H - 1)\n",
    "    cols = torch.linspace(target_region[1], target_region[3], reso[1]) * (W - 1)\n",
    "    grid_y, grid_x = torch.meshgrid((rows, cols), indexing='ij')\n",
    "\n",
    "    # one (x, y) pair per grid point, offset by half a pixel\n",
    "    points = torch.stack((grid_x, grid_y), dim=-1).reshape(-1, 2) + 0.5\n",
    "\n",
    "    def _decode(batch):\n",
    "        # features for this batch of coordinates, then the decoder, back on the CPU\n",
    "        feats, _ = model.get_coding(batch.to(model.device))\n",
    "        return model.linear_mat(feats).cpu()\n",
    "\n",
    "    outputs = [_decode(batch) for batch in torch.split(points, chunk, dim=0)]\n",
    "    return torch.cat(outputs).reshape(reso[0], reso[1], -1)\n",
    "\n",
    "def srgb_to_linear(img):\n",
    "\tlimit = 0.04045\n",
    "\treturn np.where(img > limit, np.power((img + 0.055) / 1.055, 2.4), img / 12.92)\n",
    "\n",
    "def zoom_in_animation(target_region=[0.0,0.0,0.5,0.5], n_frames=150):\n",
    "    \"\"\"Unfinished stub for a zoom-in animation towards `target_region`.\n",
    "\n",
    "    `target_region` is indexed as [y0, x0, y1, x1]. The function currently only\n",
    "    computes per-axis shifts and scales, discards them, and returns None;\n",
    "    `n_frames` is not used yet.\n",
    "    \"\"\"\n",
    "    # offset that moves the centre of the region to 0.5 on each axis\n",
    "    shiftment_y = 0.5 - (target_region[0]+target_region[2])/2\n",
    "    shiftment_x = 0.5 - (target_region[1]+target_region[3])/2\n",
    "    # reciprocal of the region's extent on each axis\n",
    "    scale_y = 1.0 / (target_region[2] - target_region[0])\n",
    "    scale_x = 1.0 / (target_region[3] - target_region[1])\n",
    "    # TODO: nothing is produced from the values above yet\n",
    "    return"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "9d2a3772-7637-4087-8890-d5e15122ffa3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the default settings, then merge the image-specific config on top of them into `cfg`.\n",
    "base_conf = OmegaConf.load('../configs/defaults.yaml')\n",
    "second_conf = OmegaConf.load('../configs/image_intro.yaml')\n",
    "cfg = OmegaConf.merge(base_conf, second_conf)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1beabf7c-0013-45f4-b02d-f80c3f8f7da5",
   "metadata": {},
   "source": [
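    "# Please pick one of the following models\n",
    "\n",
    "Each configuration cell below sets `model_name` and overwrites fields of `cfg.model`, so run only the cell for the model you want before building the dataset and training."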
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d3936dce-a035-4909-b274-998033e86769",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_mode = 'rgb' # or 'binary'\n",
    "\n",
    "# 'binary' selects the 1-channel occupancy image; any other value selects the 3-channel RGB image\n",
    "out_dim = 1 if data_mode == 'binary' else 3\n",
    "cfg.dataset.datadir = (\n",
    "    '../data/image/cat_occupancy.png' if data_mode == 'binary'\n",
    "    else '../data/image/cat_rgb.png'\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "70d91ad3-a00b-4fa3-b059-c149f8c7d7ac",
   "metadata": {},
   "source": [
    "## i. Implicit Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "82a9d865-b912-41c8-97ba-c366eeb62b65",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration 'occ': no coefficient field, a single frequency band with one basis\n",
    "# dimension, and an 8-layer network with 256 hidden units.\n",
    "model_name = 'occ'\n",
    "cfg.model.coeff_type = 'none'\n",
    "cfg.model.basis_type = 'x'\n",
    "cfg.model.basis_mapping = 'x'\n",
    "cfg.model.num_layers = 8\n",
    "cfg.model.hidden_dim = 256\n",
    "cfg.model.freq_bands=[1.]\n",
    "cfg.model.basis_dims=[1]\n",
    "# basis_resos is not overridden here, so it keeps whatever value `cfg` currently holds\n",
    "# cfg.model.basis_resos=[2160]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "16006a04-ec82-4e37-9f14-4aa2797b59d0",
   "metadata": {},
   "source": [
    "## ii. NeRF"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "cc0f8ca6-b367-4317-9aa5-ffaa2a91d518",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration 'nerf': no coefficient field, trigonometric basis mapping over ten\n",
    "# frequency bands, and an 8-layer network with 256 hidden units.\n",
    "model_name = 'nerf'\n",
    "cfg.model.coeff_type = 'none'\n",
    "cfg.model.basis_type = 'x'\n",
    "cfg.model.basis_mapping = 'trigonometric'\n",
    "cfg.model.num_layers = 8\n",
    "cfg.model.hidden_dim = 256\n",
    "# one entry per frequency band in each of the three lists below\n",
    "# (presumably they must stay the same length -- confirm against FactorFields)\n",
    "cfg.model.freq_bands=[1.,2.,4.,8.,16.,32.,64,128,256.,512.]\n",
    "cfg.model.basis_dims=[1,1,1,1,1,1,1,1,1,1]\n",
    "cfg.model.basis_resos=[1024,512,256,128,64,32,16,8,4,2]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2969be53-0a2c-4b79-a404-ac8a8c0c5afd",
   "metadata": {},
   "source": [
    "## iii. Dense Grid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "2f29657a-6fef-4dba-8f90-6d29418d43a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration 'dense-grid': a grid of coefficients with 12 channels and no basis,\n",
    "# decoded by a 2-layer network with 64 hidden units.\n",
    "model_name = 'dense-grid'\n",
    "cfg.model.coeff_type = 'grid'\n",
    "cfg.model.basis_type = 'none'\n",
    "cfg.model.coeff_reso = 128\n",
    "cfg.model.num_layers = 2\n",
    "cfg.model.hidden_dim = 64\n",
    "cfg.model.basis_dims = [12]\n",
    "cfg.model.basis_resos=[1]\n",
    "cfg.model.T_coeff = 2048000\n",
    "\n",
    "# learning rate\n",
    "# These must be assignments: `cfg.training.lr_small: 0.002` is a bare annotation\n",
    "# statement and leaves the config untouched.\n",
    "cfg.training.lr_small = 0.002\n",
    "cfg.training.lr_large = 0.002"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "acdb4ed5-be78-4ec6-8507-3ded809fe026",
   "metadata": {},
   "source": [
    "## v. Tensor Decomposition"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "4d7c0651-5684-4704-9dd7-d61982792e54",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration 'cp': vector coefficients with a 'cp' basis of 320 components at\n",
    "# resolution 1024, decoded by a 2-layer network with 64 hidden units.\n",
    "model_name = 'cp'\n",
    "cfg.model.coeff_type = 'vec'\n",
    "cfg.model.basis_type = 'cp'\n",
    "cfg.model.num_layers = 2\n",
    "cfg.model.hidden_dim = 64\n",
    "cfg.model.basis_dims = [320]\n",
    "cfg.model.freq_bands =  [1.]\n",
    "cfg.model.basis_resos =  [1024]\n",
    "cfg.model.T_basis = 2048000\n",
    "\n",
    "# learning rate\n",
    "# These must be assignments: `cfg.training.lr_small: 0.0002` is a bare annotation\n",
    "# statement and leaves the config untouched.\n",
    "cfg.training.lr_small = 0.0002\n",
    "cfg.training.lr_large = 0.002\n",
    "\n",
    "# cfg.model.coef_init: 0.001\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1233b7d2-8da7-4e4e-bab3-3f5aad410fbd",
   "metadata": {},
   "source": [
    "## iv. Hash Grid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "de648f30-8643-4372-8e5f-64e51ea5625d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration 'hash-grid': no coefficient field, a 'hash' basis with six frequency\n",
    "# bands of two dimensions each, decoded by a 2-layer network with 64 hidden units.\n",
    "model_name = 'hash-grid'\n",
    "cfg.model.coeff_type = 'none'\n",
    "cfg.model.basis_type = 'hash'\n",
    "cfg.model.coeff_reso = 0\n",
    "cfg.model.num_layers = 2\n",
    "cfg.model.hidden_dim = 64\n",
    "cfg.model.basis_dims = [2,2,2,2,2,2]\n",
    "cfg.model.freq_bands =  [1.,2.,4.,8.,16.,32.]\n",
    "cfg.model.T_basis = 2048000\n",
    "\n",
    "# learning rate\n",
    "# These must be assignments: `cfg.training.lr_small: 0.0002` is a bare annotation\n",
    "# statement and leaves the config untouched.\n",
    "cfg.training.lr_small = 0.0002\n",
    "cfg.training.lr_large = 0.002"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3de570ff-d91b-412d-a838-cb9cb56107ce",
   "metadata": {},
   "source": [
    "## vi. Dictionary Factorization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "b81fe62c-3dc2-4a87-a02f-14c788243c10",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Look up the dataset class by name and build the training split for the selected image.\n",
    "dataset = dataset_dict[cfg.dataset.dataset_name]\n",
    "train_dataset = dataset(cfg.dataset, cfg.training.batch_size, split='train', tolinear=False, perscent=0.5)\n",
    "# keep only the first `out_dim` channels of the image (1 for 'binary', 3 for 'rgb')\n",
    "train_dataset.image = train_dataset.image[...,:out_dim]\n",
    "# batch_size=None: the loader does no batching of its own; the batch size was passed to the\n",
    "# dataset above (presumably it yields ready-made batches -- confirm in dataLoader)\n",
    "train_loader = DataLoader(train_dataset,\n",
    "              num_workers=8,\n",
    "              persistent_workers=True,\n",
    "              batch_size=None,\n",
    "              pin_memory=True)\n",
    "\n",
    "# values read by the model constructor and the training loop in the cells below\n",
    "cfg.model.out_dim = out_dim\n",
    "batch_size = cfg.training.batch_size\n",
    "n_iter = cfg.training.n_iters\n",
    "\n",
    "# H and W are also read as globals by eval_img\n",
    "H,W = train_dataset.HW\n",
    "cfg.dataset.aabb = train_dataset.scene_bbox"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "634c9519-4be6-47b8-b030-616bdc4c6237",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "48 2048000 2048000\n",
      "=====> total parameters:  2047852\n",
      "FactorFields(\n",
      "  (coeffs): ParameterList(  (0): Parameter containing: [torch.float32 of size 1x12x413x413 (GPU 0)])\n",
      "  (linear_mat): MLPMixer(\n",
      "    (backbone): ModuleList(\n",
      "      (0): Linear(in_features=12, out_features=64, bias=True)\n",
      "      (1): Linear(in_features=64, out_features=3, bias=False)\n",
      "    )\n",
      "  )\n",
      ")\n",
      "total parameters:  2047852\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Iteration 09990: loss_dist = 0.00055990 psnr = 32.519: 100%|█| 10000/1\n",
      "ffmpeg version 4.3 Copyright (c) 2000-2020 the FFmpeg developers\n",
      "  built with gcc 7.3.0 (crosstool-NG 1.23.0.449-a04d0)\n",
      "  configuration: --prefix=/mnt/lustre/geiger/zyu30/.conda/envs/sdfstudio --cc=/opt/conda/conda-bld/ffmpeg_1597178665428/_build_env/bin/x86_64-conda_cos6-linux-gnu-cc --disable-doc --disable-openssl --enable-avresample --enable-gnutls --enable-hardcoded-tables --enable-libfreetype --enable-libopenh264 --enable-pic --enable-pthreads --enable-shared --disable-static --enable-version3 --enable-zlib --enable-libmp3lame\n",
      "  libavutil      56. 51.100 / 56. 51.100\n",
      "  libavcodec     58. 91.100 / 58. 91.100\n",
      "  libavformat    58. 45.100 / 58. 45.100\n",
      "  libavdevice    58. 10.100 / 58. 10.100\n",
      "  libavfilter     7. 85.100 /  7. 85.100\n",
      "  libavresample   4.  0.  0 /  4.  0.  0\n",
      "  libswscale      5.  7.100 /  5.  7.100\n",
      "  libswresample   3.  7.100 /  3.  7.100\n",
      "Input #0, mov,mp4,m4a,3gp,3g2,mj2, from '/mnt/qb/home/geiger/zyu30/Projects/Anpei/FactorFields/video/image/temp.mp4':\n",
      "  Metadata:\n",
      "    major_brand     : isom\n",
      "    minor_version   : 512\n",
      "    compatible_brands: isomiso2avc1mp41\n",
      "    encoder         : Lavf58.29.100\n",
      "  Duration: 00:00:05.10, start: 0.000000, bitrate: 144048 kb/s\n",
      "    Stream #0:0(und): Video: h264 (High 4:4:4 Predictive) (avc1 / 0x31637661), yuv420p, 2160x2160, 144045 kb/s, 30 fps, 30 tbr, 15360 tbn, 60 tbc (default)\n",
      "    Metadata:\n",
      "      handler_name    : VideoHandler\n",
      "Stream mapping:\n",
      "  Stream #0:0 -> #0:0 (h264 (native) -> mpeg4 (native))\n",
      "Press [q] to stop, [?] for help\n",
      "Output #0, mp4, to '/mnt/qb/home/geiger/zyu30/Projects/Anpei/FactorFields/video/image/cat_sparse_dense-grid_rgb.mp4':\n",
      "  Metadata:\n",
      "    major_brand     : isom\n",
      "    minor_version   : 512\n",
      "    compatible_brands: isomiso2avc1mp41\n",
      "    encoder         : Lavf58.45.100\n",
      "    Stream #0:0(und): Video: mpeg4 (mp4v / 0x7634706D), yuv420p, 2160x2160, q=2-31, 200 kb/s, 30 fps, 15360 tbn, 30 tbc (default)\n",
      "    Metadata:\n",
      "      handler_name    : VideoHandler\n",
      "      encoder         : Lavc58.91.100 mpeg4\n",
      "    Side data:\n",
      "      cpb: bitrate max/min/avg: 0/0/200000 buffer size: 0 vbv_delay: N/A\n",
      "frame=  153 fps= 73 q=31.0 Lsize=    1931kB time=00:00:05.06 bitrate=3122.4kbits/s speed=2.42x    \n",
      "video:1930kB audio:0kB subtitle:0kB other streams:0kB global headers:0kB muxing overhead: 0.079606%\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(None, None)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Train the selected model on the image and save a video of the reconstruction over time.\n",
    "model = FactorFields(cfg, device)\n",
    "print(model)\n",
    "print('total parameters: ',model.n_parameters())\n",
    "\n",
    "grad_vars = model.get_optparam_groups(lr_small=cfg.training.lr_small,lr_large=cfg.training.lr_large)\n",
    "optimizer = torch.optim.Adam(grad_vars, betas=(0.9, 0.99))\n",
    "\n",
    "H,W = train_dataset.HW\n",
    "\n",
    "# Write videos relative to the notebook (like the '../configs' paths) instead of a\n",
    "# hard-coded home directory, and make sure the folder exists.\n",
    "video_dir = os.path.join('..', 'video', 'image')\n",
    "os.makedirs(video_dir, exist_ok=True)\n",
    "temp_video = os.path.join(video_dir, 'temp.mp4')\n",
    "\n",
    "# Capture a frame about 150 times per run; max(..., 1) keeps the modulo below from\n",
    "# dividing by zero when n_iter < 150.\n",
    "frame_every = max(n_iter // 150, 1)\n",
    "\n",
    "imgs = []\n",
    "# times[k] is the optimisation time accumulated after k iterations (frame rendering\n",
    "# excluded); it gets exactly one new entry per iteration.\n",
    "psnrs,times = [],[0.0]\n",
    "loss_scale = 1.0\n",
    "lr_factor = 0.1 ** (1 / n_iter)  # loss_scale decays from 1.0 to 0.1 over n_iter steps\n",
    "pbar = tqdm(range(n_iter))\n",
    "start = time.time()\n",
    "for (iteration, sample) in zip(pbar,train_loader):\n",
    "    iteration_start = time.time()\n",
    "    loss_scale *= lr_factor\n",
    "\n",
    "    coordinates, pixel_rgb = sample['xy'], sample['rgb']\n",
    "    feats,_ = model.get_coding(coordinates.to(device))\n",
    "    y_recon = model.linear_mat(feats)\n",
    "\n",
    "    loss = torch.mean((y_recon.squeeze()-pixel_rgb.squeeze().to(device))**2)\n",
    "\n",
    "    # PSNR of the current batch, computed from the unscaled MSE\n",
    "    psnr = -10.0 * np.log(loss.item()) / np.log(10.0)\n",
    "    psnrs.append(psnr)\n",
    "\n",
    "    if iteration%10==0:\n",
    "        pbar.set_description(\n",
    "                    f'Iteration {iteration:05d}:'\n",
    "                    + f' loss_dist = {loss.item():.8f}'\n",
    "                    + f' psnr = {psnr:.3f}'\n",
    "                )\n",
    "\n",
    "    loss = loss * loss_scale\n",
    "    optimizer.zero_grad()\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    iteration_end = time.time()\n",
    "    times.append(times[-1] + iteration_end-iteration_start)\n",
    "\n",
    "    if iteration%frame_every == 0 or iteration==n_iter-1:\n",
    "        imgs.append((eval_img(train_dataset.HW).clamp(0,1.)*255).to(torch.uint8))\n",
    "\n",
    "iteration_time = time.time()-start\n",
    "\n",
    "# img = eval_img(train_dataset.HW).clamp(0,1.)\n",
    "# print(PSNR(img,train_dataset.image.view(img.shape)),iteration_time)\n",
    "# plt.figure(figsize=(10, 10))\n",
    "# plt.imshow(img)\n",
    "\n",
    "imageio.mimwrite(temp_video, imgs, fps=30, quality=10)\n",
    "# np.savetxt(os.path.join(video_dir, f'cat_{model_name}_psnr_{data_mode}.txt'),psnrs)\n",
    "# np.savetxt(os.path.join(video_dir, f'cat_{model_name}_time_{data_mode}.txt'),times)\n",
    "\n",
    "# Re-encode the temporary video with ffmpeg into the final per-model file.\n",
    "stream = ffmpeg.input(temp_video)\n",
    "stream = ffmpeg.output(stream, os.path.join(video_dir, f'cat_sparse_{model_name}_{data_mode}.mp4'))\n",
    "ffmpeg.run(stream,overwrite_output=True);"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eb303268-1a2d-4e87-ba62-616546907dff",
   "metadata": {},
   "source": [
    "# Visualization of the learned factors during training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "cb24b10e-8fa8-4e5b-954b-0be5fcd336df",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "48 2048000 2048000\n",
      "=====> total parameters:  2047852\n",
      "FactorFields(\n",
      "  (coeffs): ParameterList(  (0): Parameter containing: [torch.float32 of size 1x12x413x413 (GPU 0)])\n",
      "  (linear_mat): MLPMixer(\n",
      "    (backbone): ModuleList(\n",
      "      (0): Linear(in_features=12, out_features=64, bias=True)\n",
      "      (1): Linear(in_features=64, out_features=3, bias=False)\n",
      "    )\n",
      "  )\n",
      ")\n",
      "total parameters:  2047852\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/scratch_local/zyu30-52677/tmp/ipykernel_1781387/912974985.py:14: RuntimeWarning: invalid value encountered in true_divide\n",
      "  feat = (feat-np.min(feat))/(np.max(feat) - np.min(feat))\n",
      "Iteration 09990: loss_dist = 0.00056620 psnr = 32.470: 100%|█| 10000/1\n",
      "IMAGEIO FFMPEG_WRITER WARNING: input image is not divisible by macro_block_size=16, resizing from (413, 413) to (416, 416) to ensure video compatibility with most codecs and players. To prevent resizing, make your input image divisible by the macro_block_size or set the macro_block_size to 1 (risking incompatibility).\n",
      "[swscaler @ 0x5d4ad00] Warning: data is not aligned! This can lead to a speed loss\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #800000; text-decoration-color: #800000\">╭─────────────────────────────── </span><span style=\"color: #800000; text-decoration-color: #800000; font-weight: bold\">Traceback </span><span style=\"color: #bf7f7f; text-decoration-color: #bf7f7f; font-weight: bold\">(most recent call last)</span><span style=\"color: #800000; text-decoration-color: #800000\"> ────────────────────────────────╮</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> in <span style=\"color: #00ff00; text-decoration-color: #00ff00\">&lt;module&gt;</span>:<span style=\"color: #0000ff; text-decoration-color: #0000ff\">78</span>                                                                                   <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>                                                                                                  <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>   <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">75 │   </span><span style=\"color: #0000ff; text-decoration-color: #0000ff\">for</span> i <span style=\"color: #ff00ff; text-decoration-color: #ff00ff\">in</span> <span style=\"color: #00ffff; text-decoration-color: #00ffff\">range</span>(model.coeffs[<span style=\"color: #0000ff; text-decoration-color: #0000ff\">0</span>].shape[<span style=\"color: #0000ff; text-decoration-color: #0000ff\">1</span>]):                                               <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>   <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">76 │   │   </span>imageio.mimwrite(<span style=\"color: #808000; text-decoration-color: #808000\">'/mnt/qb/home/geiger/zyu30/Projects/Anpei/FactorFields/video/im</span>    <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>   <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">77 │   │   </span>                                                                                    <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #800000; text-decoration-color: #800000\">❱ </span>78 <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│   │   </span>os.makedirs(<span style=\"color: #808000; text-decoration-color: #808000\">f'/mnt/qb/home/geiger/zyu30/Projects/Anpei/FactorFields/video/image/</span>    <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>   <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">79 │   │   </span>stream = ffmpeg.input(<span style=\"color: #808000; text-decoration-color: #808000\">'/mnt/qb/home/geiger/zyu30/Projects/Anpei/FactorFields/vid</span>    <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>   <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">80 │   │   </span>stream = ffmpeg.output(stream, <span style=\"color: #808000; text-decoration-color: #808000\">f'/mnt/qb/home/geiger/zyu30/Projects/Anpei/Factor</span>    <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>   <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">81 │   │   </span>ffmpeg.run(stream,overwrite_output=<span style=\"color: #0000ff; text-decoration-color: #0000ff\">True</span>)                                            <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">╰──────────────────────────────────────────────────────────────────────────────────────────────────╯</span>\n",
       "<span style=\"color: #ff0000; text-decoration-color: #ff0000; font-weight: bold\">NameError: </span>name <span style=\"color: #008000; text-decoration-color: #008000\">'os'</span> is not defined\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[31m╭─\u001b[0m\u001b[31m──────────────────────────────\u001b[0m\u001b[31m \u001b[0m\u001b[1;31mTraceback \u001b[0m\u001b[1;2;31m(most recent call last)\u001b[0m\u001b[31m \u001b[0m\u001b[31m───────────────────────────────\u001b[0m\u001b[31m─╮\u001b[0m\n",
       "\u001b[31m│\u001b[0m in \u001b[92m<module>\u001b[0m:\u001b[94m78\u001b[0m                                                                                   \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m                                                                                                  \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m   \u001b[2m75 \u001b[0m\u001b[2m│   \u001b[0m\u001b[94mfor\u001b[0m i \u001b[95min\u001b[0m \u001b[96mrange\u001b[0m(model.coeffs[\u001b[94m0\u001b[0m].shape[\u001b[94m1\u001b[0m]):                                               \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m   \u001b[2m76 \u001b[0m\u001b[2m│   │   \u001b[0mimageio.mimwrite(\u001b[33m'\u001b[0m\u001b[33m/mnt/qb/home/geiger/zyu30/Projects/Anpei/FactorFields/video/im\u001b[0m    \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m   \u001b[2m77 \u001b[0m\u001b[2m│   │   \u001b[0m                                                                                    \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m78 \u001b[2m│   │   \u001b[0mos.makedirs(\u001b[33mf\u001b[0m\u001b[33m'\u001b[0m\u001b[33m/mnt/qb/home/geiger/zyu30/Projects/Anpei/FactorFields/video/image/\u001b[0m    \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m   \u001b[2m79 \u001b[0m\u001b[2m│   │   \u001b[0mstream = ffmpeg.input(\u001b[33m'\u001b[0m\u001b[33m/mnt/qb/home/geiger/zyu30/Projects/Anpei/FactorFields/vid\u001b[0m    \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m   \u001b[2m80 \u001b[0m\u001b[2m│   │   \u001b[0mstream = ffmpeg.output(stream, \u001b[33mf\u001b[0m\u001b[33m'\u001b[0m\u001b[33m/mnt/qb/home/geiger/zyu30/Projects/Anpei/Factor\u001b[0m    \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m   \u001b[2m81 \u001b[0m\u001b[2m│   │   \u001b[0mffmpeg.run(stream,overwrite_output=\u001b[94mTrue\u001b[0m)                                            \u001b[31m│\u001b[0m\n",
       "\u001b[31m╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n",
       "\u001b[1;91mNameError: \u001b[0mname \u001b[32m'os'\u001b[0m is not defined\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Retrain the selected model and record how its coefficient / basis channels evolve.\n",
    "# Only the 'dense-grid' and 'cp' configurations produce videos here.\n",
    "model = FactorFields(cfg, device)\n",
    "print(model)\n",
    "print('total parameters: ',model.n_parameters())\n",
    "\n",
    "grad_vars = model.get_optparam_groups(lr_small=cfg.training.lr_small,lr_large=cfg.training.lr_large)\n",
    "optimizer = torch.optim.Adam(grad_vars, betas=(0.9, 0.99))\n",
    "\n",
    "H,W = train_dataset.HW\n",
    "\n",
    "# Write videos relative to the notebook (like the '../configs' paths) instead of a\n",
    "# hard-coded home directory, and make sure the folder exists.\n",
    "video_dir = os.path.join('..', 'video', 'image')\n",
    "os.makedirs(video_dir, exist_ok=True)\n",
    "temp_video = os.path.join(video_dir, 'temp.mp4')\n",
    "\n",
    "# Capture a snapshot about 150 times per run; max(..., 1) avoids a modulo by zero\n",
    "# when n_iter < 150.\n",
    "frame_every = max(n_iter // 150, 1)\n",
    "\n",
    "\n",
    "def normalize01(feat):\n",
    "    \"\"\"Min-max scale `feat` to [0, 1]; a constant array is mapped to all zeros.\"\"\"\n",
    "    lo, hi = np.min(feat), np.max(feat)\n",
    "    if not hi > lo:\n",
    "        # max == min would give 0/0 = NaN (the 'invalid value' RuntimeWarning)\n",
    "        return np.zeros_like(feat)\n",
    "    return (feat - lo) / (hi - lo)\n",
    "\n",
    "\n",
    "def colorize(feat01):\n",
    "    \"\"\"Turn a [0, 1] map into a 'coolwarm' colour image via an 8-bit intermediate.\"\"\"\n",
    "    return cmapy.colorize((feat01 * 255).astype('uint8'), 'coolwarm')\n",
    "\n",
    "\n",
    "def snapshot_factors():\n",
    "    \"\"\"Return colour images of the current factors for the selected model.\n",
    "\n",
    "    'dense-grid': one image per coefficient channel.\n",
    "    'cp': for every channel, the coefficient image followed by the basis image.\n",
    "    Any other model: an empty list.\n",
    "    \"\"\"\n",
    "    frames = []\n",
    "    if model_name == 'dense-grid':\n",
    "        for c in range(model.coeffs[0].shape[1]):\n",
    "            feat = model.coeffs[0][0,c].cpu().detach().numpy()\n",
    "            frames.append(colorize(normalize01(feat)))\n",
    "    elif model_name == 'cp':\n",
    "        for c in range(model.coeffs[0].shape[1]):\n",
    "            for item in [model.coeffs[0],model.basises[0]]:\n",
    "                feat = normalize01(item[0,c].cpu().detach().numpy())\n",
    "                # stretch to a width of 64 pixels, keeping the number of rows\n",
    "                feat = cv2.resize(feat,(64,feat.shape[0]))\n",
    "                frames.append(colorize(feat))\n",
    "    return frames\n",
    "\n",
    "\n",
    "def write_video(frames, out_path):\n",
    "    \"\"\"Write `frames` to the temporary video, then re-encode it to `out_path` with ffmpeg.\"\"\"\n",
    "    imageio.mimwrite(temp_video, frames, fps=30, quality=10)\n",
    "    stream = ffmpeg.input(temp_video)\n",
    "    stream = ffmpeg.output(stream, out_path)\n",
    "    ffmpeg.run(stream,overwrite_output=True)\n",
    "\n",
    "\n",
    "# as before, only the dense grid gets a snapshot of its initial state\n",
    "imgs = snapshot_factors() if model_name == 'dense-grid' else []\n",
    "\n",
    "# times[k] is the optimisation time accumulated after k iterations (snapshots\n",
    "# excluded); it gets exactly one new entry per iteration.\n",
    "psnrs,times = [],[0.0]\n",
    "loss_scale = 1.0\n",
    "lr_factor = 0.1 ** (1 / n_iter)  # loss_scale decays from 1.0 to 0.1 over n_iter steps\n",
    "pbar = tqdm(range(n_iter))\n",
    "start = time.time()\n",
    "for (iteration, sample) in zip(pbar,train_loader):\n",
    "    iteration_start = time.time()\n",
    "    loss_scale *= lr_factor\n",
    "\n",
    "    coordinates, pixel_rgb = sample['xy'], sample['rgb']\n",
    "    feats,_ = model.get_coding(coordinates.to(device))\n",
    "    y_recon = model.linear_mat(feats)\n",
    "\n",
    "    loss = torch.mean((y_recon.squeeze()-pixel_rgb.squeeze().to(device))**2)\n",
    "\n",
    "    # PSNR of the current batch, computed from the unscaled MSE\n",
    "    psnr = -10.0 * np.log(loss.item()) / np.log(10.0)\n",
    "    psnrs.append(psnr)\n",
    "\n",
    "    if iteration%10==0:\n",
    "        pbar.set_description(\n",
    "                    f'Iteration {iteration:05d}:'\n",
    "                    + f' loss_dist = {loss.item():.8f}'\n",
    "                    + f' psnr = {psnr:.3f}'\n",
    "                )\n",
    "\n",
    "    loss = loss * loss_scale\n",
    "    optimizer.zero_grad()\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    iteration_end = time.time()\n",
    "    times.append(times[-1] + iteration_end-iteration_start)\n",
    "\n",
    "    if iteration%frame_every == 0 or iteration==n_iter-1:\n",
    "        imgs.extend(snapshot_factors())\n",
    "\n",
    "iteration_time = time.time()-start\n",
    "\n",
    "if model_name == 'dense-grid':\n",
    "    n_channels = model.coeffs[0].shape[1]\n",
    "    out_dir = os.path.join(video_dir, f'vis_sparse_cat_{model_name}_{data_mode}')\n",
    "    os.makedirs(out_dir, exist_ok=True)\n",
    "    # (snapshot, channel, ...image dims): one video per channel\n",
    "    imgs = np.stack(imgs).reshape(-1,n_channels,*imgs[0].shape)\n",
    "    for i in range(n_channels):\n",
    "        write_video(imgs[:,i], os.path.join(out_dir, f'{i:02d}.mp4'))\n",
    "elif model_name == 'cp':\n",
    "    n_channels = model.coeffs[0].shape[1]\n",
    "    out_dir = os.path.join(video_dir, f'vis_cat_{model_name}_{data_mode}')\n",
    "    os.makedirs(out_dir, exist_ok=True)\n",
    "    # snapshots are interleaved coefficient / basis, so split them by parity\n",
    "    vis_coef, vis_basis = imgs[0::2], imgs[1::2]\n",
    "    vis_coef = np.stack(vis_coef).reshape(-1,n_channels,*vis_coef[0].shape)\n",
    "    vis_basis = np.stack(vis_basis).reshape(-1,n_channels,*vis_basis[0].shape)\n",
    "    for (name,item) in zip(['coef','basis'],[vis_coef,vis_basis]):\n",
    "        for i in range(item.shape[1]):\n",
    "            write_video(item[:,i], os.path.join(out_dir, f'{name}-{i:02d}.mp4'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "7bbeb8b6-fe5f-41ad-9f99-b9d85bca075a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "c76fef7a-bd96-4629-99d0-c5077eb8e1db",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "ffmpeg version 4.3 Copyright (c) 2000-2020 the FFmpeg developers\n",
      "  built with gcc 7.3.0 (crosstool-NG 1.23.0.449-a04d0)\n",
      "  configuration: --prefix=/mnt/lustre/geiger/zyu30/.conda/envs/sdfstudio --cc=/opt/conda/conda-bld/ffmpeg_1597178665428/_build_env/bin/x86_64-conda_cos6-linux-gnu-cc --disable-doc --disable-openssl --enable-avresample --enable-gnutls --enable-hardcoded-tables --enable-libfreetype --enable-libopenh264 --enable-pic --enable-pthreads --enable-shared --disable-static --enable-version3 --enable-zlib --enable-libmp3lame\n",
      "  libavutil      56. 51.100 / 56. 51.100\n",
      "  libavcodec     58. 91.100 / 58. 91.100\n",
      "  libavformat    58. 45.100 / 58. 45.100\n",
      "  libavdevice    58. 10.100 / 58. 10.100\n",
      "  libavfilter     7. 85.100 /  7. 85.100\n",
      "  libavresample   4.  0.  0 /  4.  0.  0\n",
      "  libswscale      5.  7.100 /  5.  7.100\n",
      "  libswresample   3.  7.100 /  3.  7.100\n",
      "Input #0, mov,mp4,m4a,3gp,3g2,mj2, from '/mnt/qb/home/geiger/zyu30/Projects/Anpei/FactorFields/video/image/temp.mp4':\n",
      "  Metadata:\n",
      "    major_brand     : isom\n",
      "    minor_version   : 512\n",
      "    compatible_brands: isomiso2avc1mp41\n",
      "    encoder         : Lavf58.29.100\n",
      "  Duration: 00:00:05.13, start: 0.000000, bitrate: 10167 kb/s\n",
      "    Stream #0:0(und): Video: h264 (High 4:4:4 Predictive) (avc1 / 0x31637661), yuv420p, 416x416, 10165 kb/s, 30 fps, 30 tbr, 15360 tbn, 60 tbc (default)\n",
      "    Metadata:\n",
      "      handler_name    : VideoHandler\n",
      "Stream mapping:\n",
      "  Stream #0:0 -> #0:0 (h264 (native) -> mpeg4 (native))\n",
      "Press [q] to stop, [?] for help\n",
      "Output #0, mp4, to '/mnt/qb/home/geiger/zyu30/Projects/Anpei/FactorFields/video/image/vis_sparse_cat_dense-grid_rgb/00.mp4':\n",
      "  Metadata:\n",
      "    major_brand     : isom\n",
      "    minor_version   : 512\n",
      "    compatible_brands: isomiso2avc1mp41\n",
      "    encoder         : Lavf58.45.100\n",
      "    Stream #0:0(und): Video: mpeg4 (mp4v / 0x7634706D), yuv420p, 416x416, q=2-31, 200 kb/s, 30 fps, 15360 tbn, 30 tbc (default)\n",
      "    Metadata:\n",
      "      handler_name    : VideoHandler\n",
      "      encoder         : Lavc58.91.100 mpeg4\n",
      "    Side data:\n",
      "      cpb: bitrate max/min/avg: 0/0/200000 buffer size: 0 vbv_delay: N/A\n",
      "frame=  154 fps=0.0 q=18.3 Lsize=     434kB time=00:00:05.10 bitrate= 696.7kbits/s speed=31.9x    \n",
      "video:432kB audio:0kB subtitle:0kB other streams:0kB global headers:0kB muxing overhead: 0.346337%\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(None, None)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Manual re-run of the final re-encode step of the visualisation cell above.\n",
    "# NOTE: depends on kernel state left behind by that cell (`i`, `model_name`, `data_mode`)\n",
    "# and on temp.mp4 already existing in the video folder.\n",
    "vis_dir = os.path.join('..', 'video', 'image', f'vis_sparse_cat_{model_name}_{data_mode}')\n",
    "os.makedirs(vis_dir, exist_ok=True)\n",
    "stream = ffmpeg.input(os.path.join('..', 'video', 'image', 'temp.mp4'))\n",
    "stream = ffmpeg.output(stream, os.path.join(vis_dir, f'{i:02d}.mp4'))\n",
    "ffmpeg.run(stream,overwrite_output=True);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "455beb54-8d45-45e1-a066-898756fd53ff",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "factorfield",
   "language": "python",
   "name": "factorfield"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
